#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "../Include/point.h"

/**
 * Number of rows (samples) and columns (features) in the training dataset.
 * Defined here with external linkage; other translation units should declare
 * them with `extern int ...;` (no initializer).
 */
int TRAINING_DATASET_ROW_SIZE = 10;
int TRAINING_DATASET_COLUMN_SIZE = 5;

/**
 * Number of rows (samples) and columns (features) in the testing dataset.
 * Defined here with external linkage; other translation units should declare
 * them with `extern int ...;` (no initializer).
 */
int TESTING_DATASET_ROW_SIZE = 10;
int TESTING_DATASET_COLUMN_SIZE = 5;

void Classify(double normalTrainingDataset[TRAINING_DATASET_ROW_SIZE][TRAINING_DATASET_COLUMN_SIZE],
                       double normalTrainingDatasetLabel[TRAINING_DATASET_ROW_SIZE],
                       double normalTestingDataset[TESTING_DATASET_ROW_SIZE][TESTING_DATASET_COLUMN_SIZE],
                       double normalTestingDatasetLabel[TESTING_DATASET_ROW_SIZE], struct Point *allDistanceList[TESTING_DATASET_ROW_SIZE][TRAINING_DATASET_ROW_SIZE]) {

    //计算每一个测试数据集中的样本距离每一个训练数据集中样本的欧式距离。allDistanceList
    // 的第一行表示测试数据集的第一个样本距离训练数据集中每个样本的距离
//    struct Point *allDistanceList[TESTING_DATASET_ROW_SIZE][TRAINING_DATASET_ROW_SIZE];
    for (int i = 0; i < TESTING_DATASET_ROW_SIZE; ++i) {
        for (int j = 0; j < TRAINING_DATASET_ROW_SIZE; ++j) {
            allDistanceList[i][j] = (struct Point *) malloc(sizeof(struct Point *));
        }
    }
//    *allDistanceList = (struct Point *)malloc(sizeof(struct Point *) * TESTING_DATASET_ROW_SIZE * TRAINING_DATASET_ROW_SIZE);

    for (int i = 0; i < TESTING_DATASET_ROW_SIZE; ++i) {

        // 表示测试数据集中的某一条样本距离训练数据集中所有样本的距离
        struct Point *distanceList[TRAINING_DATASET_ROW_SIZE];
        for (int y = 0; y < TRAINING_DATASET_ROW_SIZE; ++y) {
            distanceList[y] = (struct Point *)malloc(sizeof(struct Point));
        }
        for (int j = 0; j < TESTING_DATASET_ROW_SIZE; ++j) {
            double pow_sum = 0.0;
            double sqrt_sum;
            for (int k = 0; k < TESTING_DATASET_COLUMN_SIZE; ++k) {
                double jk = normalTrainingDataset[j][k];
                double ik = normalTestingDataset[i][k];
                pow_sum += pow(jk - ik, 2);
            }
            sqrt_sum = sqrt(pow_sum);
            struct Point *point = (struct Point *)malloc(sizeof(struct Point));
            point->testingDatasetIndex = i;
            point->trainingDatasetIndex = j;
            point->distance = sqrt_sum;
            distanceList[j] = point;
        }
        // 降序排列
        for (int m = 0; m < TESTING_DATASET_ROW_SIZE; m++) {
            int max = m;
            for (int n = m + 1; n < TRAINING_DATASET_ROW_SIZE; n++) {
                if (distanceList[max]->distance < distanceList[n]->distance) {
                    max = n;
                    struct Point *temp = distanceList[m];
                    distanceList[m] = distanceList[max];
                    distanceList[max] = temp;
                }
            }
        }
        for (int x = 0; x < TRAINING_DATASET_ROW_SIZE; ++x) {
            allDistanceList[i][x] = distanceList[x];
        }
    }

//    return allDistanceList;

//    struct Point *p1=(struct Point *)malloc(sizeof(struct Point *));
//    p1->trainingDatasetIndex=1;
//    p1->testingDatasetIndex=2;
//    p1->distance=3;
//    p[0][0]=p1;
//    struct Point *p2=(struct Point *)malloc(sizeof(struct Point *));
//    p2->trainingDatasetIndex=1;
//    p2->testingDatasetIndex=2;
//    p2->distance=4;
//    p[0][1]=p2;
//    return p;
}